import torch
from torch import nn


class RelationEBMOld(nn.Module):
    """Energy model that scores (subject box, object box) pairs with an MLP."""

    def __init__(self):
        """Build the 8 -> 128 -> 128 -> 128 -> 1 LeakyReLU MLP (``g_net``)."""
        super().__init__()
        widths = (8, 128, 128, 128)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU())
        # Final projection to a single scalar energy per pair.
        stack.append(nn.Linear(widths[-1], 1))
        self.g_net = nn.Sequential(*stack)

    def forward(self, inputs):
        """
        Compute one energy per subject/object box pair.

        Inputs (a 5-element sequence, unpacked in this order):
            sbox_centers (tensor): centers of subject boxes (B, N_rel, 2)
            sbox_sizes (tensor): sizes of subject boxes (B, N_rel, 2)
            obox_centers (tensor): centers of object boxes (B, N_rel, 2)
            obox_sizes (tensor): sizes of object boxes (B, N_rel, 2)
            rels: fifth element, accepted but ignored by this model

        Returns:
            tensor with the last dimension replaced by 1 (the energy).
        """
        sbox_centers, sbox_sizes, obox_centers, obox_sizes, _ = inputs

        # Box corners: "lo" = center - size / 2, "hi" = center + size / 2
        subj_lo = sbox_centers - sbox_sizes / 2
        subj_hi = sbox_centers + sbox_sizes / 2
        obj_lo = obox_centers - obox_sizes / 2
        obj_hi = obox_centers + obox_sizes / 2

        # 8 features per pair: same-corner offsets, then cross-corner offsets
        pair_feats = torch.cat(
            [
                subj_lo - obj_lo,
                subj_hi - obj_hi,
                subj_lo - obj_hi,
                subj_hi - obj_lo,
            ],
            dim=-1,
        )

        # Compute energy
        return self.g_net(pair_feats)


class ConceptEBMShapeOrder(nn.Module):
    """Concept EBM scoring an ordered sequence of boxes against a shape."""

    def __init__(self):
        """Initialize layers."""
        super().__init__()
        # Per-box embedding of the 6 geometric features built in forward()
        self.g_net = nn.Sequential(
            nn.Linear(6, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128)
        )
        # Sequence-first transformer (batch_first is left at its default)
        layer = nn.TransformerEncoderLayer(128, 4, 128)
        self.t_net = nn.TransformerEncoder(layer, 1)
        # Per-box scalar energy
        self.f_net = nn.Linear(128, 1)

    @staticmethod
    def _cyclic_neighbor_inds(num_boxes, max_boxes):
        """Return (next, previous) slot indices, cyclic over real boxes.

        For a row with n real boxes, slot i < n maps to (i + 1) % n and
        (i - 1) % n respectively; slots i >= n (padding) map to themselves.
        Rows with n == 0 therefore get identity indices.

        Args:
            num_boxes (tensor): number of real boxes per row, (B,), integer
            max_boxes (int): total number of slots N per row

        Returns:
            two integer tensors of shape (B, N)
        """
        slots = torch.arange(max_boxes, device=num_boxes.device).unsqueeze(0)
        counts = num_boxes.unsqueeze(1)
        is_real = slots < counts
        # Clamp so rows without real boxes do not trigger a modulo by zero;
        # their result is discarded by torch.where anyway.
        modulus = counts.clamp(min=1)
        forward_inds = torch.where(is_real, (slots + 1) % modulus, slots)
        # Integer remainder takes the sign of the divisor, so -1 % n == n - 1
        backward_inds = torch.where(is_real, (slots - 1) % modulus, slots)
        return forward_inds, backward_inds

    def forward(self, box_centers, box_sizes, fixed_args):
        """
        Forward pass, returns one energy per batch element.

        Inputs:
            box_centers ([tensor]): only element 0 is used, (B, N, 2)
            box_sizes ([tensor]): only element 0 is used, (B, N, 2)
            fixed_args (list):
                centers (tensor): center of the shape (B, 2), (x, y)
                lengths (tensor): "length" of the shape (B,)

        Boxes whose size components sum to less than 1e-8 are treated as
        padding. The neighbour indexing treats the first ``n`` slots of each
        row as the real boxes, i.e. it assumes padding is trailing.

        Returns:
            tensor (B, 1): sum of per-box energies over non-padding boxes.
            A row made only of padding yields 0.
        """
        # Parse signature
        box_centers = box_centers[0]
        box_sizes = box_sizes[0]
        centers, lengths = fixed_args
        # Detect padding objects
        src_key_padding_mask = box_sizes.sum(-1) < 1e-8  # padding boxes
        num_boxes = (~src_key_padding_mask).sum(1)
        # Difference with center, normalized by the shape length
        feats = box_centers - centers.unsqueeze(1)  # (B, N, 2)
        feats = feats / lengths[:, None, None]
        # Consecutive pairwise differences (cyclic over the real boxes)
        forward_inds, backward_inds = self._cyclic_neighbor_inds(
            num_boxes, box_centers.size(1)
        )
        forward_feats = feats.gather(
            1, forward_inds.unsqueeze(-1).expand_as(feats)
        )
        backward_feats = feats.gather(
            1, backward_inds.unsqueeze(-1).expand_as(feats)
        )
        feats = torch.cat([
            feats,
            feats - forward_feats,
            feats - backward_feats
        ], 2)  # (B, N, 6)
        # Compute energy
        feats = self.g_net(feats)
        feats = feats.transpose(0, 1)  # (N, B, 128) for the transformer
        feats = self.t_net(feats, src_key_padding_mask=src_key_padding_mask)
        feats = feats.transpose(0, 1)  # back to (B, N, 128)
        energies = self.f_net(feats)  # (B, N, 1)
        # Zero padded slots by assignment rather than multiplication so a
        # non-finite value in a padded slot cannot leak into the sum.
        energies = energies.masked_fill(
            src_key_padding_mask.unsqueeze(-1), 0.0
        )
        return energies.sum(1)


class RelationEBM(nn.Module):
    """Concept EBM scoring subject/object box pairs from corner offsets."""

    def __init__(self, n_relations=6):
        """
        Initialize layers.

        Args:
            n_relations (int): accepted but currently unused; the model
                does not condition on relation ids.
        """
        super().__init__()
        hidden = 128
        self.g_net = nn.Sequential(
            nn.Linear(8, hidden),
            nn.LeakyReLU(),
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(),
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(),
            nn.Linear(hidden, 1),
        )

    @staticmethod
    def _corners(centers, sizes):
        """Return the (center - size / 2, center + size / 2) corner pair."""
        half = sizes / 2
        return centers - half, centers + half

    def forward(self, inputs):
        """
        Forward pass over a 4-element sequence of box tensors.

        Inputs:
            sbox_centers (tensor): centers of subject boxes (B, N_rel, 2)
            sbox_sizes (tensor): sizes of subject boxes (B, N_rel, 2)
            obox_centers (tensor): centers of object boxes (B, N_rel, 2)
            obox_sizes (tensor): sizes of object boxes (B, N_rel, 2)

        Returns:
            tensor with the last dimension replaced by 1 (the energy).
        """
        sbox_centers, sbox_sizes, obox_centers, obox_sizes = inputs
        # Embed boxes as corner coordinates
        s_lo, s_hi = self._corners(sbox_centers, sbox_sizes)
        o_lo, o_hi = self._corners(obox_centers, obox_sizes)
        # Same-corner offsets followed by cross-corner offsets (8 values)
        offsets = (s_lo - o_lo, s_hi - o_hi, s_lo - o_hi, s_hi - o_lo)
        feats = torch.cat(offsets, dim=-1)
        # Compute energy
        return self.g_net(feats)


class ShapeEBM(nn.Module):
    """Concept EBM scoring a set of box centers with a transformer."""

    def __init__(self):
        """Initialize layers."""
        super().__init__()
        width = 128
        # Per-object embedding of the 2 center coordinates
        self.f_net = nn.Sequential(nn.Linear(2, width), nn.LeakyReLU())
        # Batch-first transformer that contextualizes the objects
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=width, nhead=4, dim_feedforward=width, batch_first=True
        )
        self.context = nn.TransformerEncoder(encoder_layer, num_layers=4)
        # Energy head applied to the pooled representation
        self.g_net = nn.Sequential(
            nn.Linear(width, width),
            nn.LeakyReLU(),
            nn.Linear(width, width),
            nn.LeakyReLU(),
            nn.Linear(width, 1),
        )
        # Not used by forward(); kept so the module's parameters are unchanged
        self.rels = nn.Embedding(5, 16)

    def forward(self, inputs):
        """
        Forward pass over a 4-element sequence; elements 1 and 2 are ignored.

        Inputs:
            sbox_centers (tensor): centers of boxes (B, N, 2); the last
                dimension must be 2 to match the first linear layer
            mask (tensor): (B, N), 1 if real object, 0 if padding

        Returns:
            tensor (B, 1): one energy per batch element.

        NOTE: both the centering mean and the final pooling mean are taken
        over all N slots, padding slots included.
        """
        centers, _, _, valid = inputs
        # Translate so the per-sample mean center sits at the origin
        centered = centers - centers.mean(dim=1, keepdim=True)
        # The transformer expects True at padding positions
        padding = valid.bool().logical_not()
        # Embed object centers to feature vectors
        tokens = self.f_net(centered)
        # Contextualize, then pool over objects -> (B, 128)
        contextual = self.context(tokens, src_key_padding_mask=padding)
        pooled = contextual.mean(1)
        # Compute energy
        return self.g_net(pooled)
